
package com.jstarcraft.ai.jsat.classifiers.neuralnetwork;

import static java.lang.Math.max;
import static java.lang.Math.min;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.DataPointPair;
import com.jstarcraft.ai.jsat.clustering.SeedSelectionMethods;
import com.jstarcraft.ai.jsat.clustering.SeedSelectionMethods.SeedSelection;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.VecPaired;
import com.jstarcraft.ai.jsat.linear.distancemetrics.DistanceMetric;
import com.jstarcraft.ai.jsat.linear.distancemetrics.TrainableDistanceMetric;
import com.jstarcraft.ai.jsat.linear.vectorcollection.DefaultVectorCollection;
import com.jstarcraft.ai.jsat.linear.vectorcollection.VectorCollection;
import com.jstarcraft.ai.jsat.math.decayrates.DecayRate;
import com.jstarcraft.ai.jsat.math.decayrates.ExponetialDecay;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * Learning Vector Quantization (LVQ) is an algorithm that extends {@link SOM}
 * to take advantage of label information to perform classification. It creates
 * a number of representatives, or learning vectors, for each class. The LVs are
 * then updated iteratively to push away from the wrong class and pull closer to
 * the correct class. LVQ is equivalent to a type of 2 layer neural network.
 * 
 * @author Edward Raff
 */
public class LVQ implements Classifier, Parameterized {

    private static final long serialVersionUID = -3911765006048793222L;
    /**
     * The default number of iterations is {@value #DEFAULT_ITERATIONS}
     */
    public static final int DEFAULT_ITERATIONS = 200;
    /**
     * The default learning rate is {@value #DEFAULT_LEARNING_RATE}
     */
    public static final double DEFAULT_LEARNING_RATE = 0.1;
    /**
     * The default eps distance factor between the two winning vectors is
     * {@value #DEFAULT_EPS}
     */
    public static final double DEFAULT_EPS = 0.3;
    /**
     * The default scaling factor for the {@link LVQVersion#LVQ3} case is
     * {@value #DEFAULT_MSCALE}, the midpoint of the suggested [0.1, 0.5] range
     */
    public static final double DEFAULT_MSCALE = (0.5 - 0.1) / 2 + 0.1;
    /**
     * The default version of LVQ to use is {@link LVQVersion#LVQ3}
     */
    public static final LVQVersion DEFAULT_LVQ_METHOD = LVQVersion.LVQ3;
    /**
     * The default number of representatives per class is
     * {@value #DEFAULT_REPS_PER_CLASS}
     */
    public static final int DEFAULT_REPS_PER_CLASS = 3;
    /**
     * The default stopping distance for convergence is
     * {@value #DEFAULT_STOPPING_DIST}
     */
    public static final double DEFAULT_STOPPING_DIST = 1e-3;

    /**
     * The default seed selection method is {@link SeedSelection#KPP}
     */
    public static final SeedSelection DEFAULT_SEED_SELECTION = SeedSelection.KPP;

    /** Decay schedule applied to the learning rate at each training iteration */
    private DecayRate learningDecay;
    /** Maximum number of training iterations to perform */
    private int iterations;
    /** Base learning rate, decayed by {@link #learningDecay} during training */
    private double learningRate;
    /**
     * The distance metric to use
     */
    protected DistanceMetric dm;
    /** Which variant of the LVQ update rules to apply during training */
    private LVQVersion lvqVersion;
    /** Scale factor bounding the distance ratio allowed for paired updates */
    private double eps;
    /** Extra multiplier on the learning rate for the LVQ3 same-class update */
    private double mScale;
    /** Minimum movement of a learning vector needed to keep training (early stop) */
    private double stoppingDist;
    /** Number of learning vectors created for each class */
    private int representativesPerClass;
    /**
     * Array containing the learning vectors
     */
    protected Vec[] weights;
    /**
     * Array of the class that each learning vector represents
     */
    protected int[] weightClass;
    /**
     * Records the number of times each neuron won and was of the correct class
     * during training. Neurons that end with a count of zero wins will be ignored
     */
    protected int[] wins;
    /** Method used to pick the initial learning vectors for each class */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /**
     * Contains the Learning vectors paired with their index in the weights array
     */
    protected VectorCollection<VecPaired<Vec, Integer>> vc;

    /**
     * Creates a new LVQ instance using the {@link #DEFAULT_LEARNING_RATE default
     * learning rate} and {@link #DEFAULT_REPS_PER_CLASS default number of
     * representatives per class}.
     * 
     * @param dm         the distance metric to use
     * @param iterations the number of iterations to perform
     */
    public LVQ(DistanceMetric dm, int iterations) {
        this(dm, iterations, DEFAULT_LEARNING_RATE, DEFAULT_REPS_PER_CLASS);
    }

    /**
     * Creates a new LVQ instance using the {@link #DEFAULT_LVQ_METHOD default LVQ
     * version} and an {@link ExponetialDecay exponential} learning rate decay.
     * 
     * @param dm                      the distance metric to use
     * @param iterations              the number of iterations to perform
     * @param learningRate            the learning rate to use when updating
     * @param representativesPerClass the number of representatives to create for
     *                                each class
     */
    public LVQ(DistanceMetric dm, int iterations, double learningRate, int representativesPerClass) {
        this(dm, iterations, learningRate, representativesPerClass, DEFAULT_LVQ_METHOD, new ExponetialDecay());
    }

    /**
     * Creates a new LVQ instance. The epsilon distance, LVQ3 scale, stopping
     * distance, seed selection, and vector collection are all initialized to their
     * documented defaults.
     * 
     * @param dm                      the distance metric to use
     * @param iterations              the number of iterations to perform
     * @param learningRate            the learning rate to use when updating
     * @param representativesPerClass the number of representatives to create for
     *                                each class
     * @param lvqVersion              the version of LVQ to use
     * @param learningDecay           the amount of decay to apply to the learning
     *                                rate
     */
    public LVQ(DistanceMetric dm, int iterations, double learningRate, int representativesPerClass, LVQVersion lvqVersion, DecayRate learningDecay) {
        setLearningDecay(learningDecay);
        setIterations(iterations);
        setLearningRate(learningRate);
        setDistanceMetric(dm);
        setLVQMethod(lvqVersion);
        setEpsilonDistance(DEFAULT_EPS);
        setMScale(DEFAULT_MSCALE);
        // Previously never set, leaving stoppingDist at 0.0 — which made the
        // documented DEFAULT_STOPPING_DIST dead code and effectively disabled
        // the early-convergence check in train().
        setStoppingDist(DEFAULT_STOPPING_DIST);
        setSeedSelection(DEFAULT_SEED_SELECTION);
        setVecCollection(new DefaultVectorCollection<>());
        setRepresentativesPerClass(representativesPerClass);
    }

    /**
     * Copy Constructor. Performs a deep copy of the trained model state (learning
     * vectors, their classes, win counts, and the vector collection) when present.
     * 
     * @param toCopy version to copy
     */
    protected LVQ(LVQ toCopy) {
        this(toCopy.dm.clone(), toCopy.iterations, toCopy.learningRate, toCopy.representativesPerClass, toCopy.lvqVersion, toCopy.learningDecay);
        if (toCopy.weights != null) {
            // deep-copy the trained model state
            wins = Arrays.copyOf(toCopy.wins, toCopy.wins.length);
            weights = new Vec[toCopy.weights.length];
            weightClass = Arrays.copyOf(toCopy.weightClass, toCopy.weightClass.length);

            for (int i = 0; i < toCopy.weights.length; i++)
                this.weights[i] = toCopy.weights[i].clone();
        }
        setEpsilonDistance(toCopy.eps);
        setMScale(toCopy.getMScale());
        // stoppingDist was previously not copied at all, silently resetting it
        setStoppingDist(toCopy.getStoppingDist());
        setSeedSelection(toCopy.getSeedSelection());
        // BUG FIX: the clone was previously also performed unconditionally after
        // this null guard, throwing a NullPointerException whenever toCopy.vc was
        // null (and cloning twice when it was not).
        if (toCopy.vc != null)
            setVecCollection(toCopy.vc.clone());
    }

    /**
     * When using {@link LVQVersion#LVQ3}, a 3rd case exists where up to two
     * learning vectors can be updated at the same time if they have the same class.
     * To avoid over fitting, an additional regularizing weight is placed upon the
     * learning rate for their update. This sets that additional multiplier. It is
     * suggested to use a value in the range of [0.1, 0.5]
     * 
     * @param mScale the multiplication factor to apply to the learning vectors
     * @throws ArithmeticException if the value is not a positive finite number
     */
    public void setMScale(double mScale) {
        // !(mScale > 0) also rejects NaN, since all comparisons with NaN are false
        if (!(mScale > 0) || Double.isInfinite(mScale))
            throw new ArithmeticException("Scale factor must be a positive constant, not " + mScale);
        this.mScale = mScale;
    }

    /**
     * Returns the extra learning rate multiplier used for the LVQ3 update case.
     * 
     * @return a scale used during LVQ3
     */
    public double getMScale() {
        return this.mScale;
    }

    /**
     * Sets the epsilon multiplier that controls the maximum distance two learning
     * vectors can be from each other in order to be updated at the same time. If
     * they are too far apart, only one can be updated. It is recommended to use a
     * value in the range [0.1, 0.3]
     * 
     * @param eps the scale factor of the maximum distance for two learning vectors
     *            to be updated at the same time
     * @throws ArithmeticException if the value is not a positive finite number
     */
    public void setEpsilonDistance(double eps) {
        // !(eps > 0) also rejects NaN, since all comparisons with NaN are false
        if (!(eps > 0) || Double.isInfinite(eps))
            throw new ArithmeticException("eps factor must be a positive constant, not " + eps);
        this.eps = eps;
    }

    /**
     * Returns the epsilon scale factor on the distance ratio within which two
     * learning vectors may be updated at the same time.
     * 
     * @return the scale of the allowable distance between learning vectors when
     *         updating
     */
    public double getEpsilonDistance() {
        return this.eps;
    }

    /**
     * Sets the learning rate of the algorithm. It should be set in accordance with
     * {@link #setLearningDecay(com.jstarcraft.ai.jsat.math.decayrates.DecayRate) }.
     * 
     * @param learningRate the learning rate to use
     * @throws ArithmeticException if the value is not a positive finite number
     */
    public void setLearningRate(double learningRate) {
        // !(learningRate > 0) also rejects NaN, since comparisons with NaN are false
        if (!(learningRate > 0) || Double.isInfinite(learningRate))
            throw new ArithmeticException("learning rate must be a positive constant, not " + learningRate);
        this.learningRate = learningRate;
    }

    /**
     * Returns the base learning rate used when applying updates.
     * 
     * @return the learning rate to use
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the decay schedule that will be applied to the learning rate as the
     * iterations progress.
     * 
     * @param decaySchedule the rate to decay the learning rate
     */
    public void setLearningDecay(DecayRate decaySchedule) {
        this.learningDecay = decaySchedule;
    }

    /**
     * Returns the decay schedule used to shrink the learning rate over each
     * iteration.
     * 
     * @return the decay rate used at each iteration
     */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the number of learning iterations that will occur.
     * 
     * @param iterations the number of iterations for the algorithm to use
     * @throws ArithmeticException if a negative number of iterations is given
     */
    public void setIterations(int iterations) {
        if (iterations >= 0)
            this.iterations = iterations;
        else
            throw new ArithmeticException("Can not perform a negative number of iterations");
    }

    /**
     * Returns the number of training iterations that will be performed.
     * 
     * @return the number of iterations to perform
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the number of representatives to create for each class. It is possible
     * to have an unbalanced number of representatives per class, but that is not
     * currently supported. Increasing the number of representatives per class
     * increases the complexity of the decision boundary that can be learned.
     * 
     * @param representativesPerClass the number of representatives to create for
     *                                each class
     * @throws ArithmeticException if the value is not positive
     */
    public void setRepresentativesPerClass(int representativesPerClass) {
        // Every other setter in this class validates its argument; a non-positive
        // value here would silently create zero learning vectors in train() and
        // divide by a non-positive count in classify().
        if (representativesPerClass < 1)
            throw new ArithmeticException("Number of representatives per class must be a positive constant, not " + representativesPerClass);
        this.representativesPerClass = representativesPerClass;
    }

    /**
     * Returns the number of representatives to create for each class.
     * 
     * @return the number of representatives to create for each class.
     */
    public int getRepresentativesPerClass() {
        return representativesPerClass;
    }

    /**
     * Sets the version of LVQ used.
     * 
     * @param version the version of LVQ to use
     */
    public void setLVQMethod(LVQVersion version) {
        this.lvqVersion = version;
    }

    /**
     * Returns which version of the LVQ algorithm will be used.
     * 
     * @return the version of the LVQ algorithm to use.
     */
    public LVQVersion getLVQMethod() {
        return this.lvqVersion;
    }

    /**
     * Sets the distance metric used for learning and classification.
     * 
     * @param metric the distance metric to use
     */
    public void setDistanceMetric(DistanceMetric metric) {
        this.dm = metric;
    }

    /**
     * Returns the distance metric in use.
     * 
     * @return the distance metric to use
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * The algorithm terminates early if the learning vectors are only moving small
     * distances. The stopping distance is the minimum distance that one of the
     * learning vectors must move for the algorithm to continue.
     * 
     * @param stoppingDist the minimum distance for each learning vector to move
     * @throws ArithmeticException if the value is negative, infinite, or NaN
     */
    public void setStoppingDist(double stoppingDist) {
        // !(stoppingDist >= 0) also rejects NaN; zero remains a legal value
        if (!(stoppingDist >= 0) || Double.isInfinite(stoppingDist))
            throw new ArithmeticException("stopping dist must be a zero or positive constant, not " + stoppingDist);
        this.stoppingDist = stoppingDist;
    }

    /**
     * Returns the stopping distance used to terminate the algorithm early.
     * 
     * @return the stopping distance used to end the algorithm early
     */
    public double getStoppingDist() {
        return this.stoppingDist;
    }

    /**
     * Sets the seed selection method used to pick the initial learning vectors.
     * 
     * @param method the method of initializing LVQ
     */
    public void setSeedSelection(SeedSelection method) {
        this.seedSelection = method;
    }

    /**
     * Returns the seed selection method in use.
     * 
     * @return the method of seed selection used
     */
    public SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * There are several LVQ versions, each one adding an additional case in which
     * two LVs instead of one can be updated.
     */
    public enum LVQVersion {
        /**
         * LVQ1 will only update one LV
         */
        LVQ1,
        /**
         * Two vectors will be updated if they are close enough together: the closest
         * was the wrong class but the 2nd closest was the correct class.
         */
        LVQ2,
        /**
         * Two vectors will be updated if they are close enough together and do not
         * belong to the same class, if one of them was the correct class for the
         * training vector.
         */
        LVQ21,
        /**
         * Two vectors will be updated if they are close enough together and are of the
         * same class as the training vector.
         */
        LVQ3
    }

    /**
     * Sets the vector collection used to store (and search) the final learning
     * vectors after training.
     * 
     * @param collection the vector collection to use
     */
    public void setVecCollection(VectorCollection<VecPaired<Vec, Integer>> collection) {
        this.vc = collection;
    }

    /**
     * Classifies a point by finding its nearest learning vector and assigning all
     * probability mass to that vector's class.
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        int numClasses = weightClass.length / representativesPerClass;
        CategoricalResults cr = new CategoricalResults(numClasses);

        Vec query = data.getNumericalValues();
        int bestIndex = vc.search(query, 1).get(0).getVector().getPair();
        cr.setProb(weightClass[bestIndex], 1.0);

        return cr;
    }

    /**
     * Returns true if the two distance values are within an acceptable epsilon
     * ratio of each other, i.e. both ratios fall inside the (1-eps, 1+eps) window.
     * 
     * @param minDist  the first distance
     * @param minDist2 the second distance
     * @return <tt>true</tt> if they are acceptably close
     */
    protected boolean epsClose(double minDist, double minDist2) {
        double ratio = minDist / minDist2;
        double inverse = minDist2 / minDist;
        return min(ratio, inverse) > (1 - eps) && max(ratio, inverse) < (1 + eps);
    }

    /**
     * Trains the LVQ model. Learning vectors are seeded per class with the
     * configured {@link SeedSelection} method, then refined for up to
     * {@link #getIterations() } passes over the data using the update rules of the
     * configured {@link LVQVersion}. Training stops early when no learning vector
     * moves more than the stopping distance in a full pass. The surviving vectors
     * (those that won at least one point of their own class in the final pass) are
     * stored in {@link #vc} for use by {@link #classify(DataPoint) }.
     *
     * @param dataSet  the labeled data set to learn from
     * @param parallel whether the vector collection build may use multiple threads
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        TrainableDistanceMetric.trainIfNeeded(dm, dataSet, parallel);
        Random rand = RandomUtil.getRandom();
        int classCount = dataSet.getPredicting().getNumOfCategories();
        weights = new Vec[classCount * representativesPerClass];
        Vec[] weightsPrev = new Vec[weights.length];
        weightClass = new int[weights.length];
        wins = new int[weights.length];

        // Generate weights that are hopefully close to their final positions

        int curClass = 0;
        int curPos = 0;
        while (curClass < classCount) {
            List<DataPoint> origSubList = dataSet.getSamples(curClass);
            List<DataPointPair<Integer>> subList = new ArrayList<>(origSubList.size());
            for (DataPoint dp : origSubList)
                subList.add(new DataPointPair<>(dp, curClass));
            ClassificationDataSet subSet = new ClassificationDataSet(subList, dataSet.getPredicting());
            List<Vec> classSeeds = SeedSelectionMethods.selectIntialPoints(subSet, representativesPerClass, dm, rand, seedSelection);
            for (Vec v : classSeeds) {
                weights[curPos] = v.clone();
                weightsPrev[curPos] = weights[curPos].clone();
                weightClass[curPos++] = curClass;
            }
            curClass++;
        }
        Vec tmp = weights[0].clone();// scratch vector reused for every update

        for (int iteration = 0; iteration < iterations; iteration++) {
            // Remember where each vector started this pass for the convergence check
            for (int j = 0; j < weights.length; j++)
                weights[j].copyTo(weightsPrev[j]);
            Arrays.fill(wins, 0);
            double alpha = learningDecay.rate(iteration, iterations, learningRate);
            for (int i = 0; i < dataSet.size(); i++) {
                Vec x = dataSet.getDataPoint(i).getNumericalValues();
                int trueClass = dataSet.getDataPointCategory(i);
                int minDistIndx = 0, minDistIndx2 = 0;
                double minDist = Double.POSITIVE_INFINITY, minDist2 = Double.POSITIVE_INFINITY;

                /*
                 * Track the two nearest learning vectors. BUG FIX: the runner-up
                 * used to be tracked only when lvqVersion == LVQ2, so the paired
                 * updates of LVQ2.1 and LVQ3 (the default) could never trigger and
                 * training degenerated to LVQ1. It was also only refreshed when a
                 * new best appeared, which could miss the true second nearest.
                 */
                for (int j = 0; j < weights.length; j++) {
                    double dist = dm.dist(x, weights[j]);
                    if (dist < minDist) {
                        minDist2 = minDist;
                        minDistIndx2 = minDistIndx;
                        minDist = dist;
                        minDistIndx = j;
                    } else if (dist < minDist2) {
                        minDist2 = dist;
                        minDistIndx2 = j;
                    }
                }

                if (lvqVersion.ordinal() >= LVQVersion.LVQ2.ordinal() && weightClass[minDistIndx] != weightClass[minDistIndx2] && trueClass == weightClass[minDistIndx2] && epsClose(minDist, minDist2)) {
                    // LVQ2 window: closest is the wrong class, runner-up is correct.
                    // Move the closest farther away
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx]);
                    weights[minDistIndx].mutableSubtract(alpha, tmp);
                    // And the 2nd closest closer
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx2]);
                    weights[minDistIndx2].mutableAdd(alpha, tmp);
                    wins[minDistIndx2]++;
                } else if (lvqVersion.ordinal() >= LVQVersion.LVQ21.ordinal() && weightClass[minDistIndx] != weightClass[minDistIndx2] && trueClass == weightClass[minDistIndx] && epsClose(minDist, minDist2)) {
                    // LVQ2.1 window: closest is correct, runner-up is the wrong class.
                    // Move the closest closer
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx]);
                    weights[minDistIndx].mutableAdd(alpha, tmp);
                    wins[minDistIndx]++;
                    // And the 2nd closest farther away
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx2]);
                    weights[minDistIndx2].mutableSubtract(alpha, tmp);
                } else if (lvqVersion.ordinal() >= LVQVersion.LVQ3.ordinal() && weightClass[minDistIndx] == weightClass[minDistIndx2] && min(minDist / minDist2, minDist2 / minDist) > (1 - eps) * (1 + eps)) {
                    // LVQ3 case: both nearest vectors share a class; pull both in,
                    // damped by mScale to avoid over fitting
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx]);
                    weights[minDistIndx].mutableAdd(mScale * alpha, tmp);

                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx2]);
                    weights[minDistIndx2].mutableAdd(mScale * alpha, tmp);
                    wins[minDistIndx]++;
                    wins[minDistIndx2]++;
                } else // Base case, can only update one vector
                {
                    x.copyTo(tmp);
                    tmp.mutableSubtract(weights[minDistIndx]);
                    if (trueClass == weightClass[minDistIndx])// Move closer to the right class
                    {
                        wins[minDistIndx]++;
                        weights[minDistIndx].mutableAdd(alpha, tmp);
                    } else// Move farther away
                    {
                        weights[minDistIndx].mutableSubtract(alpha, tmp);
                    }
                }

            }
            // Check for early convergence; stop scanning once any vector has moved
            // far enough to rule it out
            boolean stopEarly = true;
            for (int j = 0; j < weights.length && stopEarly; j++)
                if (dm.dist(weights[j], weightsPrev[j]) > stoppingDist)
                    stopEarly = false;
            if (stopEarly)
                break;
        }

        // Only keep the learning vectors that won at least one point of their class
        List<VecPaired<Vec, Integer>> finalLVs = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; i++)
            if (wins[i] > 0)
                finalLVs.add(new VecPaired<>(weights[i], i));
        vc.build(parallel, finalLVs, dm);
    }

    @Override
    public boolean supportsWeightedData() {
        // Data point weights are not used by the training procedure
        return false;
    }

    @Override
    public LVQ clone() {
        // Uses the deep-copy constructor rather than Object.clone()
        return new LVQ(this);
    }

}
